import os
import json
import habitat_sim
import numpy as np
import networkx as nx
from tqdm import tqdm
from hovsg.data.hm3dsem.habitat_utils import make_cfg_mp3d

def load_nav_graph(connectivity_dir, scan):
    """Build the navigation graph of one scan from its connectivity file.

    Reads ``<connectivity_dir>/<scan>_connectivity.json``, a list of viewpoint
    records. For every included viewpoint, an edge is added for each truthy
    ``unobstructed[j]`` flag whose target record ``j`` is also included. The
    edge weight is the Euclidean distance between the pose entries at indices
    3, 7 and 11 (presumably the translation of a flattened 4x4 pose — TODO
    confirm).

    Args:
        connectivity_dir: Directory containing the ``*_connectivity.json`` files.
        scan: Scan identifier used to build the file name.

    Returns:
        An undirected ``networkx.Graph`` keyed by ``image_id``. Each node gets a
        ``position`` attribute: a numpy array of pose[3], pose[7], pose[11].

    Raises:
        AssertionError: if a link is recorded in only one direction.
    """

    def distance(pose1, pose2):
        '''Euclidean distance between two graph poses.'''
        p, q = pose1['pose'], pose2['pose']
        return ((p[3] - q[3])**2 + (p[7] - q[7])**2 + (p[11] - q[11])**2)**0.5

    json_path = os.path.join(connectivity_dir, '%s_connectivity.json' % scan)
    with open(json_path) as f:
        records = json.load(f)

    graph = nx.Graph()
    positions = {}
    for src_idx, src in enumerate(records):
        if not src['included']:
            continue
        for dst_idx, linked in enumerate(src['unobstructed']):
            if not linked or not records[dst_idx]['included']:
                continue
            dst = records[dst_idx]
            pose = src['pose']
            positions[src['image_id']] = np.array([pose[3], pose[7], pose[11]])
            # Symmetry check: the reverse link must be recorded as well.
            assert dst['unobstructed'][src_idx], 'Graph should be undirected'
            graph.add_edge(src['image_id'], dst['image_id'], weight=distance(src, dst))
    nx.set_node_attributes(graph, values=positions, name='position')

    return graph

def get_connectivity_map(scan, vp2pos, connectivity_dir="../VLN-DUET/datasets/R2R/connectivity"):
    """Return the rounded node positions of a scan's navigation graph.

    Args:
        scan: Scan identifier.
        vp2pos: Mapping viewpoint id -> 3D position; only viewpoints present
            here appear in the returned map.
        connectivity_dir: Directory holding the ``<scan>_connectivity.json``
            files (defaults to the previously hard-coded R2R location).

    Returns:
        ``(connectivity_map, nodes2id)`` where ``nodes2id`` maps every graph
        node (viewpoint id) to an integer id in graph node order, and
        ``connectivity_map`` maps those ids to the position rounded to 2
        decimals. Ids are assigned before filtering by ``vp2pos``, so the keys
        of ``connectivity_map`` may be non-contiguous.
    """
    G = load_nav_graph(connectivity_dir, scan)

    nodes2id = {node: i for i, node in enumerate(G.nodes())}
    connectivity_map = {
        nodes2id[node]: [round(x, 2) for x in vp2pos[node]]
        for node in G.nodes()
        if node in vp2pos
    }
    return connectivity_map, nodes2id

def get_vp2pos(scan, vp2pos_dir="../HOV-SG/vp2pos"):
    """Load the viewpoint -> position mapping stored in ``vp2pos_<scan>.json``.

    Args:
        scan: Scan identifier used to build the file name.
        vp2pos_dir: Directory containing the ``vp2pos_*.json`` files
            (defaults to the previously hard-coded location).

    Returns:
        The decoded JSON object (used by callers as a dict of viewpoint id ->
        position list).
    """
    path = os.path.join(vp2pos_dir, f"vp2pos_{scan}.json")
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_scan_objects_info(scan, root_dataset_dir="../scene_datasets/mp3d"):
    """List the semantic objects of an MP3D scan with their rounded centers.

    Starts a habitat simulator on ``<root_dataset_dir>/<scan>/<scan>.glb``,
    reads its semantic scene and closes the simulator again (also on error).

    Args:
        scan: Scan identifier.
        root_dataset_dir: MP3D dataset root (defaults to the previously
            hard-coded location).

    Returns:
        A list of ``{"id": obj.id, "position": [x, y, z]}`` dicts, one per
        object in ``semantic_scene.objects`` order, where ``position`` is the
        AABB center rounded to 2 decimals.
        NOTE(review): callers index this list by an integer object index, which
        assumes list position == object index — confirm for MP3D scenes.
    """
    scene_data_dir = f"{root_dataset_dir}/{scan}/"
    scene_mesh = os.path.join(scene_data_dir, scan + ".glb")

    sim_settings = {
        "scene": scene_mesh,
        "default_agent": 0,
        "sensor_height": 1.5,
        "color_sensor": True,
        "depth_sensor": True,
        "semantic_sensor": True,
        "lidar_sensor": False,
        "move_forward": 0.2,
        "move_backward": 0.2,
        "turn_left": 5,
        "turn_right": 5,
        "look_up": 5,
        "look_down": 5,
        "look_left": 5,
        "look_right": 5,
        "width": 1080,
        "height": 720,
        "enable_physics": False,
        "seed": 42,
        "lidar_fov": 360,
        "depth_img_for_lidar_n": 20,
    }
    # Silence Magnum / habitat-sim logging before the simulator is created.
    os.environ["MAGNUM_LOG"] = "quiet"
    os.environ["HABITAT_SIM_LOG"] = "quiet"

    sim_cfg = make_cfg_mp3d(sim_settings, root_dataset_dir, scene_data_dir, scan, False)
    sim = habitat_sim.Simulator(sim_cfg)
    try:
        obj_infos = [
            {
                "id": obj.id,
                "position": [round(x, 2) for x in obj.aabb.center.tolist()],
            }
            for obj in sim.semantic_scene.objects
        ]
    finally:
        # Always release the simulator, even if reading the scene fails;
        # otherwise its resources leak across scans.
        sim.close()

    return obj_infos

def generate_user_content(original_info, lm_info, connectivity_map, start_pos, objs_info, node2id):
    """Render the user-turn prompt for one instruction.

    Args:
        original_info: Dict with an ``'instruction'`` string and a
            ``'landmarks'`` string. ``'landmarks'`` is newline-separated: one
            ``"<n>. <name> (<type>)"`` line per landmark, followed by a final
            ``"<label>: <target name> (<type>)"`` line.
        lm_info: Maps the landmark index as a string (``"0"``, ``"1"``, ...) to
            its candidate list: viewpoint ids for rooms, integer-like object
            indices for objects.
        connectivity_map: Node id -> rounded position.
        start_pos: Agent start position, rendered verbatim.
        objs_info: Object info dicts, indexed by integer object index.
        node2id: Viewpoint id -> node id.

    Returns:
        The prompt string.

    Raises:
        KeyError / IndexError / ValueError when the inputs do not follow the
        format above (the caller skips such samples).
    """
    instruction = original_info['instruction']

    # Drop blank lines (e.g. a trailing newline) so the last line really is
    # the target line and not an empty string.
    lines = [ln for ln in original_info['landmarks'].split('\n') if ln.strip()]
    landmarks, target = lines[:-1], lines[-1]
    # maxsplit / rsplit keep '.', ':' and '(' that belong to the name itself,
    # e.g. "1. St. Mary statue (object)" or "2. picture (framed) (object)".
    landmarks = [x.split('.', 1)[1].strip() for x in landmarks]
    target = target.split(':', 1)[1].strip().rsplit('(', 1)[0].strip()

    landmark_names = [x.rsplit('(', 1)[0].strip() for x in landmarks]
    landmark_types = [x.rsplit('(', 1)[1].split(')')[0].strip() for x in landmarks]

    parts = [f"1. 'Instruction': {instruction}\n2. 'Candidates':\n"]
    for idx, (lm, lm_type) in enumerate(zip(landmark_names, landmark_types)):
        # Looked up for every landmark (floors included): a missing entry
        # raises KeyError, which makes the caller skip the sample.
        cands = lm_info[str(idx)]
        parts.append(f"Landmark: {lm} ({lm_type})\n")
        kind = lm_type.lower()
        if kind == 'floor':
            parts.append("This landmark is a floor which need to be inferred from the connectivity map. Note that floor index starts from 1.\n")
        elif kind == 'room':
            for i, cand in enumerate(cands, start=1):
                node_id = node2id[cand]
                parts.append(f"Candidate {i}: <id: {node_id}, position: {connectivity_map[node_id]}>\n")
        elif kind == 'object':
            for i, cand in enumerate(cands, start=1):
                parts.append(f"Candidate {i}: {objs_info[int(cand)]}\n")
            parts.append('\n')

    parts.append(
        f"3. 'Target': {target}\n"
        f"4. 'Connectivity Map': {connectivity_map}\n"
        f"5. 'Start Position': {start_pos}\n"
    )
    return ''.join(parts)

# System prompt attached to every generated sample. It describes the inputs
# rendered by generate_user_content (instruction, candidates, target,
# connectivity map, start position) and the answer format used for the
# assistant turn ("The correct candidate is <id>."). Objects are described only
# by id and center position, so the prompt must not mention bounding boxes.
system_prompt = """[Task Background]
You are an advanced 3D environment understanding assistant. Your main objective is to interpret a language-based instruction describing an indoor environment and identify which candidate landmark best matches the specified target.

[Input Definitions]
1. 'Instruction': A natural language description that involves spatial relationships (e.g., relative positions, distances) among landmarks.
2. 'Candidates': A list of candidates for all the landmarks mentioned in the instruction. Each landmark is one of the following:
    - Floor: No explicit candidate data is given. You must infer the positions belonging to this floor from the connectivity map.
    - Room: The candidates are a list of nodes in the connectivity map.
    - Object: Includes the object's unique identifier and the 3D coordinates of its center.
3. 'Target': The specific landmark name within the instruction that you must locate among the given candidates.
4. 'Connectivity Map': A representation of the environment's layout, including a list of key positions.
5. 'Start Position': The agent's initial 3D location, which may be referenced in the instruction.

[Coordinate System]
All 3D positions (x, y, z) follow the convention:
- x-axis: Left to right, increasing to the right.
- y-axis: Floor to ceiling, increasing upward.
- z-axis: Front to back, increasing forward.

[Output Requirements]
Analyze the provided information to decide which candidate is the correct match for the 'Target'. Consider all clues from the natural language description—particularly any spatial relationships—and compare them with the 3D positions of the candidates.
Your output must identify a single candidate as the correct match, where <Candidate_ID> is the 'id' of that candidate. Format your answer as:
    'The correct candidate is <Candidate_ID>.'
"""

# Dataset split whose landmark/REVERIE annotations are paired with the candidates.
split = "val_unseen"
# Directory of candidate JSON files; each is read by the main loop as
# {scan: {instruction_id: {landmark_index_str: [candidates]}}}.
data_path = "../HOV-SG/node_generation_gt/full_candidates"
# Sorted so files are processed (and logged) in a deterministic order.
files = sorted(os.listdir(data_path))

# Parsed landmark annotations, indexed as original_data[scan][instruction_id].
original_data_path = f"../HOV-SG/train_data_prepare/REVERIE_landmarks/{split}_objs.json"
with open(original_data_path, 'r', encoding='utf-8') as f:
    original_data = json.load(f)

# REVERIE annotations keyed by their 'id' field.
with open(f"../VLN-DUET/datasets/REVERIE/annotations/REVERIE_{split}_enc.json", 'r', encoding='utf-8') as f:
    reverie_data = {x['id']: x for x in json.load(f)}


# One output file (same name as the candidate file) is written per input file.
output_dir = "../HOV-SG/node_generation_gt/llm_input"
os.makedirs(output_dir, exist_ok=True)

# Per-scan inputs, cached in case a scan appears in more than one candidate
# file (building a habitat Simulator per scan is presumably the slow part).
scan_cache = {}

for file in tqdm(files):
    file_path = os.path.join(data_path, file)
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    train_data = []
    for scan, scan_info in data.items():
        if scan not in scan_cache:
            objs = get_scan_objects_info(scan)
            vp2pos = get_vp2pos(scan)
            connectivity_map, node2id = get_connectivity_map(scan, vp2pos)
            scan_cache[scan] = (objs, vp2pos, connectivity_map, node2id)
        scan_objects_info, vp2pos, connectivity_map, node2id = scan_cache[scan]

        for inst_id, lm_info in scan_info.items():
            try:
                # NOTE(review): inst_id[:-2] presumably strips a 2-char
                # per-instruction suffix to get the REVERIE id — confirm.
                start_vp = reverie_data[inst_id[:-2]]['path'][0]
                start_pos = [round(x, 2) for x in vp2pos[start_vp]]
                user_content = generate_user_content(original_data[scan][inst_id], lm_info, connectivity_map, start_pos, scan_objects_info, node2id)
                # The second '_'-separated field of inst_id is used as a list
                # index into scan_objects_info to get the ground-truth object.
                gt_target = scan_objects_info[int(inst_id.split('_')[1])]['id']
            except Exception as err:
                # Skip samples that cannot be built, but report why instead of
                # silently swallowing everything (incl. KeyboardInterrupt).
                tqdm.write(f"{scan} {inst_id} {err!r}")
                continue

            train_data.append({
                "sample_id": inst_id,
                "messages": [
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": f"The correct candidate is {gt_target}."},
                ],
                "system": system_prompt,
            })

    tqdm.write(f"Processed {file}, total samples: {len(train_data)}")
    with open(os.path.join(output_dir, file), 'w', encoding='utf-8') as f:
        json.dump(train_data, f, indent=4)
